import os
import re
import importlib
import subprocess
from datetime import datetime
from typing import TypedDict

from django.core.management.base import BaseCommand, CommandError

import requests
import textcase
from git import Repo
from structlog import get_logger

from products.data_warehouse.backend import types

# Scaffolding runs against a local dev environment: force debug mode and skip
# the async-migrations bootstrap that would otherwise run during Django setup.
os.environ["DEBUG"] = "1"
os.environ["SKIP_ASYNC_MIGRATIONS_SETUP"] = "1"


# Starter source.py written into the new source's package by
# _setup_source_structure(). Placeholders ({pascal}, {caps}, {constant},
# {kebab}, {git_user}) are filled from NameTransforms plus the git user name.
SOURCE_TEMPLATE = """\
from typing import cast

from posthog.schema import (
    ExternalDataSourceType as SchemaExternalDataSourceType,
    SourceConfig,
)

from posthog.temporal.data_imports.pipelines.pipeline.typings import SourceInputs, SourceResponse
from posthog.temporal.data_imports.sources.common.base import BaseSource, FieldType
from posthog.temporal.data_imports.sources.common.config import Config
from posthog.temporal.data_imports.sources.common.registry import SourceRegistry
from posthog.temporal.data_imports.sources.common.schema import SourceSchema
from products.data_warehouse.backend.types import ExternalDataSourceType

# TODO({git_user}): implement the source logic for {pascal}Source


@SourceRegistry.register
class {pascal}Source(BaseSource[Config]):
    @property
    def source_type(self) -> ExternalDataSourceType:
        return ExternalDataSourceType.{caps}

    @property
    def get_source_config(self) -> SourceConfig:
        return SourceConfig(
            name=SchemaExternalDataSourceType.{constant},
            iconPath="/static/services/{kebab}.png",
            label="{pascal}",  # only needed if the readable name is complex. delete otherwise
            caption=None,  # only needed if you want to inline docs
            docsUrl=None,  # TODO({git_user}): link to the docs in the website, full path including https://
            fields=cast(list[FieldType], []), # TODO({git_user}): add source config fields here
            unreleasedSource=True,
        )

    def validate_credentials(self, config: Config, team_id: int) -> tuple[bool, str | None]:
        # TODO({git_user}): implement the logic to validate the credentials of your source,
        # e.g. check the validity of API keys. returns a tuple of whether the credentials are valid,
        # and if not, returns an error message to return to the user
        raise NotImplementedError()

    def get_schemas(self, config: Config, team_id: int, with_counts: bool = False) -> list[SourceSchema]:
        raise NotImplementedError()

    def source_for_pipeline(self, config: Config, inputs: SourceInputs) -> SourceResponse:
        raise NotImplementedError()
"""

# Django migration emitted by _migrate(): refreshes the source_type choices on
# ExternalDataSource after the new enum entry is added to the warehouse types.
# {choices} is rendered from the (reloaded) ExternalDataSourceType.choices.
MIGRATION_TEMPLATE = """\
# Generated by create_datastack_source command on {timestamp}

from django.db import migrations, models


class Migration(migrations.Migration):
    dependencies = [
        ("posthog", "{max_migration}"),
    ]

    operations = [
        migrations.AlterField(
            model_name="externaldatasource",
            name="source_type",
            field=models.CharField(
                choices={choices},
                max_length=128,
            ),
        ),
    ]
"""

logger = get_logger(__name__)


# A source name rendered in every casing convention the scaffolding templates
# need, e.g. "Meta Ads" -> pascal "MetaAds", snake "meta_ads", kebab
# "meta-ads", constant "META_ADS", caps "METAADS".
NameTransforms = TypedDict(
    "NameTransforms",
    {
        "pascal": str,
        "snake": str,
        "kebab": str,
        "constant": str,
        "caps": str,
    },
)


class Command(BaseCommand):
    """Scaffold a new (unreleased) data warehouse external data source.

    Creates the source package with a starter ``source.py``, fetches a logo,
    registers the new type in the Python/TS enums and the sources package,
    writes the Django migration, then regenerates configs/schemas and formats.
    """

    help = "Create a new unreleased data warehouse source"

    def add_arguments(self, parser):
        """Register optional CLI flags; anything omitted is prompted for interactively."""
        parser.add_argument(
            "--name",
            type=str,
            help="Name of the external data source (e.g., Stripe, Meta Ads)",
        )
        parser.add_argument(
            "--site",
            type=str,
            help="A URL for retrieving the source logo in the form domain.com (e.g., shopify.com, stripe.com)",
        )

    def handle(self, *args, **options):
        """Run the full scaffolding pipeline for one new source.

        Raises:
            CommandError: when no source name is provided (flag or prompt).
        """
        name = options.get("name")
        if not name:
            name = input("What source are you scaffolding? (e.g. Stripe, Meta Ads): ").strip()
        if not name:
            raise CommandError("You entered an empty name for this source. Aborting...")

        site_url = options.get("site")
        if not site_url:
            site_url = input(
                "What site can we use to retrieve the logo image? (e.g., stripe.com) Press enter to skip: "
            ).strip()

        # Locate the repo root from wherever the command is invoked.
        repo = Repo(".", search_parent_directories=True)

        name_transforms: NameTransforms = {
            "pascal": textcase.pascal(name),
            "snake": textcase.snake(name),
            "kebab": textcase.kebab(name),
            "constant": textcase.constant(name),
            # e.g. "META_ADS" -> "METAADS"
            "caps": textcase.constant(name).replace("_", ""),
        }
        self._fix_common_endings(transforms=name_transforms)
        logo_filename = f"{name_transforms['kebab']}.png"

        self._setup_source_structure(repo, transforms=name_transforms)
        self._pull_logo_png(repo, url=site_url, filename=logo_filename)

        self._add_warehouse_types_enum(repo, transforms=name_transforms)
        self._add_schema_py_enum(repo, transforms=name_transforms)
        self._add_schema_general_ts_list_item(repo, transforms=name_transforms)
        self._update_sources_init(repo, transforms=name_transforms)

        # A dirty max_migration.txt usually means the branch already carries a
        # new migration; writing another would conflict with it.
        if self._has_pending_migrations(repo):
            self.stdout.write(
                self.style.WARNING(
                    "The max_migration.txt file is modified in the working tree. Skipping makemigrations..."
                )
            )
        else:
            self._migrate(repo)

        self._generate_source_configs(transforms=name_transforms)
        self._update_config_references(repo, transforms=name_transforms)
        self._schema_build(transforms=name_transforms)
        self._format_files()

    def _fix_common_endings(self, transforms: NameTransforms) -> None:
        """Upper-case common acronym suffixes in the pascal name (e.g. "MongoDb" -> "MongoDB")."""
        common_endings = ("Io", "Ai", "Db", "Ci")
        for end in common_endings:
            if transforms["pascal"].endswith(end):
                transforms["pascal"] = transforms["pascal"][: -len(end)] + end.upper()

    def _setup_source_structure(self, repo: Repo, transforms: NameTransforms) -> None:
        """Create the source package directory and write the starter source.py.

        Existing directories/files are left untouched so re-runs are safe.
        """
        sources_root = os.path.join(repo.working_dir, "posthog", "temporal", "data_imports", "sources")
        assert os.path.exists(sources_root), f"Sources root {sources_root} not found"

        source_dir = os.path.join(sources_root, transforms["snake"])
        if not os.path.exists(source_dir):
            os.mkdir(source_dir)
            self.stdout.write(self.style.SUCCESS(f"Created directory: {source_dir}"))
        else:
            self.stdout.write(self.style.WARNING(f"Directory exists: {source_dir}"))

        # The git user name is baked into the template's TODO markers.
        git_user = str(repo.config_reader().get_value("user", "name"))
        starter_template = SOURCE_TEMPLATE.format(git_user=git_user, **transforms)

        source_file = os.path.join(source_dir, "source.py")
        if not os.path.exists(source_file):
            with open(source_file, "w") as f:
                f.write(starter_template)
            self.stdout.write(self.style.SUCCESS(f"Created file: {source_file}"))
        else:
            self.stdout.write(self.style.WARNING(f"File exists: {source_file}"))

    def _pull_logo_png(self, repo: Repo, url: str | None, filename: str) -> None:
        """Download the source logo from logo.dev into frontend/public/services.

        Prints a warning telling the user to add the image manually when no
        site was given or the download fails; never raises.
        """
        path = os.path.join(repo.working_dir, "frontend", "public", "services", filename)
        if os.path.exists(path):
            self.stdout.write(self.style.WARNING(f"Logo image {path} already exists. Skipping..."))
            return
        warn_msg = "You will need to supply a logo image file in frontend/public/services manually..."
        if url:
            # Publishable (non-secret) logo.dev token; safe to keep in source.
            publishable_key = "pk_ObrK3F_LS8u7MUbD1itCTA"
            logo_dev_url = f"https://img.logo.dev/{url}?token={publishable_key}&size=256&format=png"
            # timeout so a hung CDN request cannot stall the scaffolding command
            res = requests.get(logo_dev_url, stream=True, timeout=30)
            if res.status_code == 200:
                with open(path, "wb") as f:
                    for chunk in res.iter_content(chunk_size=8 * 1024):
                        if chunk:
                            f.write(chunk)
                return
            warn_msg = f"Failed to retrieve logo for site {url} from logo.dev. " + warn_msg
        else:
            warn_msg = "No site was provided to retrieve a logo image. " + warn_msg
        self.stdout.write(self.style.WARNING(warn_msg))

    def _split_file_by_regex(self, file: str, regex: str) -> tuple[str, str]:
        """Returns file contents pre and post a regex match (pre is inclusive of the regex match)"""
        assert os.path.exists(file), f"File not found: {file}"

        with open(file) as f:
            content = f.read()
        match = re.search(regex, content, re.MULTILINE)
        # NOTE(review): message says "create_warehouse_source" but the migration
        # header says "create_datastack_source" — confirm the real command name.
        assert match, f"File {file} no longer conforms to the expected format for the create_warehouse_source command"

        insert_idx = match.end()
        return content[:insert_idx], content[insert_idx:]

    def _entry_exists_in_contiguous_text_block(self, entry: str, block: str) -> bool:
        """Return True if *entry* appears (as a substring) in any line of the
        contiguous run of non-blank lines at the start of *block*."""
        for line in block.split("\n"):
            if not line.strip():
                # First blank line ends the contiguous block.
                break
            if entry in line.strip():
                return True
        return False

    def _format_file_line(self, line: str, indent_level: int = 1, end: str = "\n") -> str:
        """Indent *line* by 4 spaces per level and ensure it ends with *end*."""
        indent_spaces = 4 * indent_level * " "
        line = indent_spaces + line
        if not line.endswith(end):
            line += end
        return line

    def _insert_line_after_regex(self, file: str, regex: str, entry: str, line_content: str) -> None:
        """Insert a formatted line directly after the regex match in *file*,
        unless *entry* already appears in the contiguous block that follows.

        Shared by the enum/list registration steps so each stays one-liner thin.
        """
        pre, post = self._split_file_by_regex(file, regex)
        if self._entry_exists_in_contiguous_text_block(entry=entry, block=post):
            self.stdout.write(self.style.WARNING(f"Source entry already exists in {file}. Skipping..."))
            return
        line = self._format_file_line(line_content)
        with open(file, "w") as f:
            f.write("".join([pre, line, post]))
        self.stdout.write(self.style.SUCCESS(f"Added source entry to {file}..."))

    def _has_pending_migrations(self, repo: Repo) -> bool:
        """Return True if max_migration.txt has uncommitted (staged or unstaged) changes."""
        migration_file = "max_migration.txt"
        # Compute each diff once (the original ran every diff twice); check both
        # a_path and b_path so renamed/moved entries are caught too.
        diffs = [*repo.index.diff(None), *repo.index.diff("HEAD")]
        changed_paths = [path for item in diffs for path in (item.a_path, item.b_path) if path]
        return any(migration_file in path for path in changed_paths)

    def _add_warehouse_types_enum(self, repo: Repo, transforms: NameTransforms) -> None:
        """Add the new source to the ExternalDataSourceType model TextChoices enum."""
        file = os.path.join(repo.working_dir, "posthog", "warehouse", "types.py")
        assert os.path.exists(file), f"File not found {file}"

        key, val = transforms["caps"], transforms["pascal"]
        regex = r"(^class ExternalDataSourceType\(models\.TextChoices\)\:\n)"
        self._insert_line_after_regex(file, regex, entry=key, line_content=f'{key} = "{val}", "{val}"')

    def _add_schema_py_enum(self, repo: Repo, transforms: NameTransforms) -> None:
        """Add the new source to the generated ExternalDataSourceType StrEnum in schema.py."""
        file = os.path.join(repo.working_dir, "posthog", "schema.py")
        assert os.path.exists(file), f"File not found {file}"

        key, val = transforms["constant"], transforms["pascal"]
        regex = r"(^class ExternalDataSourceType\(StrEnum\)\:\n)"
        self._insert_line_after_regex(file, regex, entry=key, line_content=f'{key} = "{val}"')

    def _add_schema_general_ts_list_item(self, repo: Repo, transforms: NameTransforms) -> None:
        """Add the new source to the externalDataSources list in the frontend TS schema."""
        file = os.path.join(repo.working_dir, "frontend", "src", "queries", "schema", "schema-general.ts")
        assert os.path.exists(file), f"File not found {file}"

        val = transforms["pascal"]
        regex = r"(^export const externalDataSources = \[\n)"
        self._insert_line_after_regex(file, regex, entry=val, line_content=f'"{val}",')

    def _update_sources_init(self, repo: Repo, transforms: NameTransforms) -> None:
        """Register the new source class in the sources package __init__.py:
        an import line at the top plus an entry in __all__."""
        file = os.path.join(repo.working_dir, "posthog", "temporal", "data_imports", "sources", "__init__.py")
        assert os.path.exists(file), f"File not found {file}"

        import_line = f"from .{transforms['snake']}.source import {transforms['pascal']}Source\n"
        with open(file) as f:
            content = f.read()
        # Only prepend the import when it is not already present, so re-running
        # the command does not stack duplicate import lines (the previous
        # version wrote it unconditionally before the existence check below).
        if import_line not in content:
            with open(file, "w") as f:
                f.write(import_line + content)

        val = f"{transforms['pascal']}Source"
        regex = r"(^__all__ = \[\n)"
        self._insert_line_after_regex(file, regex, entry=val, line_content=f'"{val}",')

    def _migrate(self, repo: Repo) -> None:
        """Write the AlterField migration refreshing source_type choices.

        Reads max_migration.txt to derive the next migration number and the
        dependency, renders MIGRATION_TEMPLATE, then bumps max_migration.txt.
        """
        importlib.reload(types)  # reload types to include our modifications

        migrations_dir = os.path.join(repo.working_dir, "posthog", "migrations")
        assert os.path.exists(migrations_dir), "Migrations dir not found. Yikes..."

        with open(os.path.join(migrations_dir, "max_migration.txt")) as f:
            max_migration = f.read().strip()

        # Migration filenames look like "<zero-padded number>_<description>".
        migration_num = int(max_migration.split("_")[0]) + 1
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M")
        choices = types.ExternalDataSourceType.choices
        migration_name = f"{migration_num:04d}_alter_externaldatasource_source_type.py"

        with open(os.path.join(migrations_dir, migration_name), "w") as f:
            f.write(MIGRATION_TEMPLATE.format(timestamp=timestamp, max_migration=max_migration, choices=choices))

        with open(os.path.join(migrations_dir, "max_migration.txt"), "w") as f:
            f.write(migration_name.removesuffix(".py"))

    def _generate_source_configs(self, transforms: NameTransforms) -> None:
        """Run the pnpm task that regenerates the typed source configs; warn on failure."""
        try:
            subprocess.run(["pnpm", "run", "generate:source-configs"], check=True)
            self.stdout.write(self.style.SUCCESS(f"Generated source config for {transforms['pascal']}..."))
        except subprocess.CalledProcessError as e:
            self.stdout.write(self.style.ERROR(f"Failed to generate source configs with error: {e}"))

    def _update_config_references(self, repo: Repo, transforms: NameTransforms) -> None:
        """Point the scaffolded source.py at its generated per-source config class.

        Swaps the generic Config import for the generated {Pascal}SourceConfig
        and rewrites bare "Config" references (the lookbehind leaves the
        already-correct "SourceConfig" occurrences alone).
        """
        file = os.path.join(
            repo.working_dir, "posthog", "temporal", "data_imports", "sources", transforms["snake"], "source.py"
        )
        assert os.path.exists(file), f"File {file} not found..."

        with open(file) as f:
            content = f.read()

        pascal = transforms["pascal"]
        content = content.replace(
            "from posthog.temporal.data_imports.sources.common.config import Config",
            f"from posthog.temporal.data_imports.sources.generated_configs import {pascal}SourceConfig",
        )
        new_config = f"{pascal}SourceConfig"
        regex = r"(?<!Source)Config"
        content = re.sub(regex, new_config, content)

        with open(file, "w") as f:
            f.write(content)

        self.stdout.write(self.style.SUCCESS(f"Updated config references in {file}"))

    def _schema_build(self, transforms: NameTransforms) -> None:
        """Rebuild the generated schema artifacts via pnpm; warn on failure."""
        try:
            subprocess.run(["pnpm", "run", "schema:build"], check=True)
            self.stdout.write(self.style.SUCCESS(f"Built schema for {transforms['pascal']}..."))
        except subprocess.CalledProcessError as e:
            self.stdout.write(self.style.ERROR(f"Failed to build schema with error: {e}"))

    def _format_files(self) -> None:
        """Run ruff format over the repo to normalize everything we touched; warn on failure."""
        try:
            subprocess.run(["ruff", "format"], check=True)
            self.stdout.write(self.style.SUCCESS("Formatted files. Done."))
        except subprocess.CalledProcessError as e:
            self.stdout.write(self.style.ERROR(f"Failed to format files with error: {e}"))
